import subprocess
import time
import re
import matplotlib.pyplot as plt
import numpy as np

# Collected memory-usage samples for every NPU: entry i is the list of
# readings for NPU i (IDs 0-7), in the order they were taken. Each list is a
# distinct object so appending to one never affects the others.
npu_memory_data = [list() for _npu_index in range(8)]

def get_npu_memory_usage(npu_id):
    """Return the HBM memory currently in use on one NPU, in MB.

    Runs ``npu-smi info -t usages -i <npu_id>`` and derives the amount in use
    from the "HBM Usage Rate" (percent) and "HBM Capacity" (MB) fields of
    its output.

    Args:
        npu_id: ID passed to ``npu-smi`` via ``-i``.

    Returns:
        ``usage_rate * capacity / 100`` (a float) when both fields are found,
        or the integer ``0`` when either field is missing from the output.

    Raises:
        subprocess.CalledProcessError: if ``npu-smi`` exits with a non-zero status.
        FileNotFoundError: if the ``npu-smi`` executable cannot be found.
    """
    raw = subprocess.check_output(['npu-smi', 'info', '-t', 'usages', '-i', str(npu_id)])
    text = raw.decode('utf-8')

    # Pick the numeric values of "HBM Usage Rate(%)" and "HBM Capacity(MB)".
    rate_match = re.search(r'HBM Usage Rate.*: (\d+)', text)
    capacity_match = re.search(r'HBM Capacity.*: (\d+)', text)
    if not (rate_match and capacity_match):
        # Unparseable output is reported as "nothing in use" instead of failing.
        return 0

    rate_pct = int(rate_match.group(1))
    capacity_mb = int(capacity_match.group(1))
    return (rate_pct * capacity_mb) / 100

def monitor_npu_memory(interval=1):
    """Sample the memory usage of every NPU forever, once per ``interval`` seconds.

    Each pass queries every NPU tracked in the module-level ``npu_memory_data``
    (one list per NPU ID) through ``get_npu_memory_usage`` and appends one
    reading to that NPU's list. The loop never returns on its own; it ends
    when an exception propagates, e.g. ``KeyboardInterrupt`` on Ctrl+C or an
    error raised while querying ``npu-smi``.

    Args:
        interval: target number of seconds between the starts of consecutive
            passes.
    """
    next_tick = time.monotonic()
    while True:
        for npu_id, samples in enumerate(npu_memory_data):
            samples.append(get_npu_memory_usage(npu_id))
        # Schedule against a fixed monotonic clock instead of sleeping a full
        # `interval` after the queries: the npu-smi calls take time of their
        # own, and sleeping `interval` on top of them would stretch every pass,
        # so sample k would no longer correspond to k * interval seconds.
        next_tick += interval
        delay = next_tick - time.monotonic()
        if delay > 0:
            time.sleep(delay)
        else:
            # Queries took longer than `interval`; resync rather than firing
            # a burst of back-to-back passes to "catch up".
            next_tick = time.monotonic()

def plot_npu_memory_data(interval=1):
    """Plot the collected per-NPU memory usage, save it as PNG and show it.

    Draws one line per entry of the module-level ``npu_memory_data``, writes
    the figure to ``npu_memory_usage.png`` and then calls ``plt.show()``.

    Args:
        interval: seconds between consecutive samples, used to put the x axis
            in seconds. Defaults to 1, the default sampling interval of
            ``monitor_npu_memory``.
    """
    # Series can differ in length by one pass when monitoring is interrupted
    # mid-sweep, so shorter ones are padded with NaN up to the longest
    # (matplotlib leaves NaN points undrawn).
    max_len = max((len(samples) for samples in npu_memory_data), default=0)
    # Derive the x axis from max_len -- the length of every padded series --
    # rather than from NPU 0's list, so x and y lengths always agree.
    times = np.arange(max_len) * interval
    plt.figure(figsize=(12, 6))
    for npu_id, samples in enumerate(npu_memory_data):
        padded = np.array(samples + [np.nan] * (max_len - len(samples)), dtype=float)
        plt.plot(times, padded, label=f'NPU {npu_id}')
    plt.xlabel('Time (s)')
    plt.ylabel('Memory Usage (MB)')
    plt.title('NPU Memory Usage Over Time')
    plt.legend()
    plt.grid(True)
    plt.savefig('npu_memory_usage.png')
    plt.show()

if __name__ == '__main__':
    try:
        print("Monitoring NPU memory usage. Press Ctrl+C to stop.")
        monitor_npu_memory()
    except KeyboardInterrupt:
        print("\nMonitoring stopped.")
    finally:
        # Persist and plot whatever was collected however monitoring ended.
        # Previously only Ctrl+C reached this code, so any other exception
        # (e.g. npu-smi exiting non-zero) discarded every sample gathered so
        # far. Any non-KeyboardInterrupt exception is re-raised after this.
        with open('npu_memory_data.txt', 'w', encoding='utf-8') as f:
            for npu_id, samples in enumerate(npu_memory_data):
                # One line per NPU; every value is followed by a single space.
                f.write(f'NPU {npu_id} Memory Usage: ' + ''.join(f'{data} ' for data in samples) + '\n')
        plot_npu_memory_data()
        print("Memory usage data has been plotted and saved to npu_memory_usage.png and npu_memory_data.txt")
